# This file is part of the MapProxy project.
# Copyright (C) 2010 Omniscale <http://omniscale.de>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import shutil
import tempfile
import os
import re
import sys
from glob import glob as globfunc
from contextlib import contextmanager
from lxml import etree

from mapproxy.test import mocker


class Mocker(object):
    """
    Base class for unit tests built on ``mocker``, using the xUnit-style
    ``setup_method`` / ``teardown_method`` hook names.

    ``setup_method`` creates a fresh `mocker.Mocker` and stores it as
    ``self.mocker``; ``teardown_method`` calls ``self.mocker.verify()``.
    """

    def setup_method(self):
        self.mocker = mocker.Mocker()

    def expect_and_return(self, mock_call, return_val):
        """
        Register ``return_val`` as the result of the mock call.

        ``mock_call`` itself is not used here; the argument exists so the
        caller evaluates the call expression (presumably recording it on the
        mocker) before the result is registered.

        :param return_val: The value mock_call should return.
        """
        self.mocker.result(return_val)

    def expect(self, mock_call):
        """Return ``mocker.expect(mock_call)`` for further expectation setup."""
        return mocker.expect(mock_call)

    def replay(self):
        """
        Finish the mock-record phase.
        """
        self.mocker.replay()

    def mock(self, base_cls=None):
        """
        Create and return a new mock object.

        :param base_cls: optional class whose method signatures the mock
            calls are checked against; ignored when falsy.
        """
        mock_args = (base_cls,) if base_cls else ()
        return self.mocker.mock(*mock_args)

    def teardown_method(self):
        self.mocker.verify()


class TempFiles(object):
    """
    This class is a context manager for temporary files.

    On enter, ``n`` files with the given ``suffix`` are created with
    :func:`tempfile.mkstemp` and their paths are returned as a list. With
    ``no_create=True`` each file is deleted again right away, so only unique,
    currently unused paths are returned. On exit, every path that still
    exists is removed.

    >>> with TempFiles(n=2, suffix='.png') as tmp:
    ...     for f in tmp:
    ...         assert os.path.exists(f)
    >>> for f in tmp:
    ...     assert not os.path.exists(f)
    """

    def __init__(self, n=1, suffix='', no_create=False):
        self.n = n
        self.suffix = suffix
        self.no_create = no_create
        self.tmp_files = []

    def __enter__(self):
        try:
            for _ in range(self.n):
                fd, tmp_file = tempfile.mkstemp(suffix=self.suffix)
                # track the path before anything else can fail, so cleanup
                # below always knows about it
                self.tmp_files.append(tmp_file)
                os.close(fd)
                if self.no_create:
                    os.remove(tmp_file)
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so remove the
            # files created so far instead of leaking them.
            self._remove_files()
            raise
        return self.tmp_files

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._remove_files()

    def _remove_files(self):
        """Delete every tracked file that still exists and reset the list."""
        for tmp_file in self.tmp_files:
            try:
                os.remove(tmp_file)
            except FileNotFoundError:
                # already gone (``no_create`` or removed by the caller)
                pass
        self.tmp_files = []


class TempFile(TempFiles):
    """
    Single-file variant of `TempFiles`: entering the context returns one
    path instead of a list of paths.
    """

    def __init__(self, suffix='', no_create=False):
        super(TempFile, self).__init__(suffix=suffix, no_create=no_create)

    def __enter__(self):
        created = super(TempFile, self).__enter__()
        return created[0]


class TempDir:
    """
    Context manager around :func:`tempfile.mkdtemp`.

    Entering returns the path of a new temporary directory; exiting removes
    that directory tree (if it still exists), ignoring removal errors.
    """

    def __enter__(self):
        new_dir = tempfile.mkdtemp()
        self.tmp_dir = new_dir
        return new_dir

    def __exit__(self, exc_type, exc_val, exc_tb):
        target = self.tmp_dir
        if not os.path.exists(target):
            return
        shutil.rmtree(target, ignore_errors=True)


class ChangeWorkingDir:
    """
    Context manager that changes the process working directory to
    ``new_dir`` and changes back to the previous directory on exit.

    Entering the context returns ``None``.
    """

    def __init__(self, new_dir):
        self.new_dir = new_dir

    def __enter__(self):
        previous = os.getcwd()
        self.old_dir = previous
        os.chdir(self.new_dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        restore_to = self.old_dir
        os.chdir(restore_to)


class LogMock(object):
    """
    Context manager that temporarily replaces ``module.<log_name>`` with
    itself and records calls to the methods listed in ``log_methods``.

    Each call is stored in ``logged_msgs`` as a ``(method_name, message)``
    tuple; positional arguments are %-interpolated into the message.
    The original attribute is restored on exit.

    :param module: object (usually a module) whose logger attribute is
        replaced.
    :param log_name: name of the logger attribute on ``module``.
    """
    log_methods = ('info', 'debug', 'warn', 'warning', 'error', 'exception',
                   'critical', 'fail')

    def __init__(self, module, log_name='log'):
        self.module = module
        self.log_name = log_name
        self.orig_logger = None
        self.logged_msgs = []

    def __enter__(self):
        self.orig_logger = getattr(self.module, self.log_name)
        setattr(self.module, self.log_name, self)
        return self

    def __getattr__(self, name):
        # only called for attributes not found normally, i.e. the log methods
        if name in self.log_methods:
            def _log(msg, *args, **kwargs):
                # accept logging-style calls; keyword arguments such as
                # exc_info are ignored
                if args:
                    msg = msg % args
                self.logged_msgs.append((name, msg))
            return _log
        raise AttributeError("'%s' object has no attribute '%s'" %
                             (self.__class__.__name__, name))

    def assert_log(self, type, msg):
        """
        Pop the oldest recorded message and assert its method and content.

        :param type: expected log method name (e.g. ``'info'``).
        :param msg: substring expected in the lowercased log message, so it
            should itself be lowercase.
        """
        assert self.logged_msgs, \
            'expected %s log message, but nothing was logged' % (type,)
        log_type, log_msg = self.logged_msgs.pop(0)
        assert log_type == type, 'expected %s log message, but was %s' % (type, log_type)
        assert msg in str(log_msg).lower(), "expected string '%s' in log message '%s'" % \
            (msg, log_msg)

    def __exit__(self, exc_type, exc_val, exc_tb):
        setattr(self.module, self.log_name, self.orig_logger)


def assert_re(value, regex):
    """
    Assert that ``regex`` matches somewhere in ``value`` (``re.search``).

    >>> assert_re('hello', 'l+')
    >>> assert_re('hello', 'l{3}')
    Traceback (most recent call last):
        ...
    AssertionError: hello ~= l{3}
    """
    found = re.search(regex, value) is not None
    assert found, '%s ~= %s' % (value, regex)


def assert_files_in_dir(dir, expected, glob=None):
    """
    Assert that exactly the ``expected`` file names are present in ``dir``.

    :param dir: directory to inspect.
    :param expected: iterable of file names without a directory part; order
        does not matter.
    :param glob: optional globbing pattern relative to ``dir``. When set,
        only matching entries are compared and all other files are ignored.
    """
    if glob is None:
        found = sorted(os.listdir(dir))
    else:
        matches = globfunc(os.path.join(dir, glob))
        found = sorted(os.path.basename(path) for path in matches)
    wanted = sorted(expected)
    assert wanted == found, f'{", ".join(wanted)} ~= {", ".join(found)}'


def validate_with_dtd(doc, dtd_name, dtd_basedir=None):
    """
    Validate ``doc`` against the DTD file ``dtd_name``.

    :param doc: XML as ``str``/``bytes`` (parsed with ``etree.XML``); any
        other object is passed to ``dtd.validate`` unchanged (presumably an
        already parsed lxml element or tree).
    :param dtd_name: DTD file name, relative to ``dtd_basedir``.
    :param dtd_basedir: directory holding the DTD; defaults to the
        ``schemas`` directory next to this module.
    :returns: the result of ``dtd.validate``. The filtered error log is
        printed in every case.
    """
    basedir = dtd_basedir
    if basedir is None:
        basedir = os.path.join(os.path.dirname(__file__), 'schemas')

    with open(os.path.join(basedir, dtd_name), 'rb') as dtd_file:
        dtd = etree.DTD(dtd_file)
        tree = etree.XML(doc) if isinstance(doc, (str, bytes)) else doc
        valid = dtd.validate(tree)
        print(dtd.error_log.filter_from_errors())
        return valid


def validate_with_xsd(doc, xsd_name, xsd_basedir=None):
    """
    Validate ``doc`` against the XML Schema file ``xsd_name``.

    :param doc: XML as ``str``/``bytes`` (parsed with ``etree.XML``); any
        other object is passed to ``validate`` unchanged (presumably an
        already parsed lxml element or tree).
    :param xsd_name: schema file name, relative to ``xsd_basedir``.
    :param xsd_basedir: directory holding the schema; defaults to the
        ``schemas`` directory next to this module.
    :returns: the result of ``XMLSchema.validate``. The filtered error log
        is printed in every case.
    """
    if xsd_basedir is not None:
        basedir = xsd_basedir
    else:
        basedir = os.path.join(os.path.dirname(__file__), 'schemas')

    with open(os.path.join(basedir, xsd_name), 'rb') as xsd_file:
        schema = etree.XMLSchema(etree.parse(xsd_file))
        tree = doc
        if isinstance(doc, (str, bytes)):
            tree = etree.XML(doc)
        valid = schema.validate(tree)
        print(schema.error_log.filter_from_errors())
        return valid


class XPathValidator(object):
    """
    Helper for asserting XPath expressions against an XML document.

    :param doc: XML document, parsed with ``etree.XML``.
    """

    def __init__(self, doc):
        self.xml = etree.XML(doc)

    def assert_xpath(self, xpath, expected=None):
        """
        Assert that ``xpath`` matches something in the document.

        If ``expected`` is given, the first result is checked as well: a
        callable is called with the first result and must return a truthy
        value; any other value must compare equal to the first result.
        """
        # evaluate once; the original re-ran the same query up to three times
        results = self.xml.xpath(xpath)
        assert len(results) > 0, xpath + ' does not match anything'
        if expected is None:
            return
        first = results[0]
        if callable(expected):
            assert expected(first), \
                '%s: check %r failed for %r' % (xpath, expected, first)
        else:
            assert first == expected, \
                '%s: expected %r, got %r' % (xpath, expected, first)

    def xpath(self, xpath):
        """Return the raw result of evaluating ``xpath`` on the document."""
        return self.xml.xpath(xpath)


def strip_whitespace(data):
    """
    Remove every whitespace character from ``data`` (``str`` or ``bytes``).

    >>> strip_whitespace(' <foo> bar\\n zing\\t1')
    '<foo>barzing1'
    """
    if not isinstance(data, bytes):
        return re.sub(r'\s+', '', data)
    return re.sub(br'\s+', b'', data)


@contextmanager
def capture(bytes=False):
    """
    Temporarily redirect ``sys.stdout`` and ``sys.stderr`` to in-memory buffers.

    Yields the ``(stdout, stderr)`` buffers: ``io.BytesIO`` objects when
    ``bytes`` is true, ``io.StringIO`` objects otherwise.

    If the body raises an ``Exception``, the exception text and everything
    captured so far are written to the real streams before the exception is
    re-raised, so the output of a failing test is not lost. The original
    streams are always restored.
    """
    if bytes:
        from io import BytesIO as StringIO
    else:
        from io import StringIO

    backup_stdout = sys.stdout
    backup_stderr = sys.stderr
    # keep our own references: the body may rebind sys.stdout/sys.stderr
    captured_stdout = StringIO()
    captured_stderr = StringIO()

    try:
        sys.stdout = captured_stdout
        sys.stderr = captured_stderr
        yield captured_stdout, captured_stderr
    except Exception as ex:
        backup_stdout.write(str(ex) + '\n')
        out = captured_stdout.getvalue()
        err = captured_stderr.getvalue()
        if bytes:
            # 'replace' so undecodable output cannot raise a
            # UnicodeDecodeError that hides the original exception
            out = out.decode('utf-8', 'replace')
            err = err.decode('utf-8', 'replace')
        backup_stdout.write(out)
        backup_stderr.write(err)
        raise
    finally:
        sys.stdout = backup_stdout
        sys.stderr = backup_stderr


def assert_permissions(file_path, permissions):
    """
    Assert that the permission bits (``st_mode & 0o777``) of ``file_path``
    equal ``permissions``.

    :param file_path: path passed to ``os.stat``.
    :param permissions: expected mode, either as an octal string such as
        ``'644'`` / ``'0644'`` or as an int such as ``0o644``.
    """
    actual_permissions = oct(os.stat(file_path).st_mode & 0o777)
    if isinstance(permissions, int):
        desired_mode = permissions
    else:
        desired_mode = int(permissions, base=8)
    desired_permissions = oct(desired_mode)
    assert actual_permissions == desired_permissions, f'{actual_permissions} ~= {desired_permissions}'
